{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPT-2 private inference with SPU\n",
    "\n",
    "In the [Neural Network with SPU](./nn_with_spu.ipynb) lab, we demonstrated how to use SecretFlow/SPU to train a neural network model privately.\n",
    "\n",
    "In this lab, we showcase how to run private inference on a pre-trained [GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) model for text generation with SPU.\n",
    "\n",
    "First, we show how to generate text with the pre-trained GPT-2 model using JAX and the Hugging Face Transformers library. After that, we show how to run the same text generation privately on SPU with only minor modifications to the plaintext code."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> The following code is for demonstration only. It is **NOT for production** due to system security concerns; please **DO NOT** use it directly in production."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> This tutorial may need more resources than 16 CPU cores and 48 GB of memory (16c48g)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text generation using GPT-2 with JAX/Flax\n",
    "### Install the transformers library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "!{sys.executable} -m pip install transformers[flax]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> The JAX version required by transformers conflicts with the one required by SPU. However, this example still runs with the conflicting JAX version."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the pre-trained GPT-2 Model\n",
    "\n",
    "Please refer to this [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "pretrained_model = FlaxGPT2LMHeadModel.from_pretrained(\"gpt2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the text generation function\n",
    "\n",
    "We use a [greedy search strategy](https://huggingface.co/blog/how-to-generate) for text generation here: at each step, the token with the highest logit is appended to the input sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def text_generation(input_ids, params):\n",
    "    # The default GPT2Config matches the architecture of the pre-trained gpt2 checkpoint.\n",
    "    config = GPT2Config()\n",
    "    model = FlaxGPT2LMHeadModel(config=config)\n",
    "\n",
    "    # Greedy search: append the highest-scoring next token, 10 times.\n",
    "    for _ in range(10):\n",
    "        outputs = model(input_ids=input_ids, params=params)\n",
    "        next_token_logits = outputs[0][0, -1, :]\n",
    "        next_token = jnp.argmax(next_token_logits)\n",
    "        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)\n",
    "    return input_ids"
   ]
  },
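  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The greedy search loop can also be illustrated without a model. The following cell is a minimal sketch and not part of the SPU workflow: `TOY_SCORES` and `toy_greedy_decode` are hypothetical names, and the small score table is an assumption that stands in for the GPT-2 logits over a three-token vocabulary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical score table: TOY_SCORES[t][v] is the score of token v following token t.\n",
    "# It stands in for the logits that GPT-2 would produce.\n",
    "TOY_SCORES = {\n",
    "    0: [0.1, 0.7, 0.2],\n",
    "    1: [0.2, 0.1, 0.7],\n",
    "    2: [0.6, 0.3, 0.1],\n",
    "}\n",
    "\n",
    "\n",
    "def toy_greedy_decode(token_ids, steps):\n",
    "    # Repeatedly append the highest-scoring next token (greedy search).\n",
    "    ids = list(token_ids)\n",
    "    for _ in range(steps):\n",
    "        scores = TOY_SCORES[ids[-1]]\n",
    "        next_token = max(range(len(scores)), key=scores.__getitem__)\n",
    "        ids.append(next_token)\n",
    "    return ids\n",
    "\n",
    "\n",
    "print(toy_greedy_decode([0], 4))  # [0, 1, 2, 0, 1]"
   ]
  },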
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run text generation on CPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-15 17:07:55.627043: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst\n",
      "2023-06-15 17:07:55.627112: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst\n",
      "2023-06-15 17:07:55.627118: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------------------------------------------\n",
      "Run on CPU:\n",
      "-----------------------------------------------------------------\n",
      "I enjoy walking with my cute dog, but I'm not sure if I'll ever\n",
      "-----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "inputs_ids = tokenizer.encode(\n",
    "    'I enjoy walking with my cute dog', return_tensors='jax'\n",
    ")\n",
    "outputs_ids = text_generation(inputs_ids, pretrained_model.params)\n",
    "\n",
    "print('-'*65 + '\\nRun on CPU:\\n' + '-'*65)\n",
    "print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))\n",
    "print('-'*65)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we generate 10 tokens. Keep the generated text in mind; we will generate text on SPU in the next step and compare the two results."
   ]
  },
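  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run text generation on SPU\n",
    "\n",
    "In this step, Alice provides the model parameters and Bob provides the input token IDs. Both are moved to the SPU device, where the same `text_generation` function runs."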
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-15 17:08:14,157\tINFO worker.py:1538 -- Started a local Ray instance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(pid=2109508)\u001b[0m Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n",
      "\u001b[2m\u001b[36m(pid=2109408)\u001b[0m Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n",
      "\u001b[2m\u001b[36m(pid=2121303)\u001b[0m Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n",
      "\u001b[2m\u001b[36m(pid=2121304)\u001b[0m Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n",
      "\u001b[2m\u001b[36m(pid=2121301)\u001b[0m Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n",
      "\u001b[2m\u001b[36m(_run pid=2109408)\u001b[0m INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n",
      "\u001b[2m\u001b[36m(_run pid=2109408)\u001b[0m INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host\n",
      "\u001b[2m\u001b[36m(_run pid=2109408)\u001b[0m INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n",
      "\u001b[2m\u001b[36m(_run pid=2109408)\u001b[0m WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
      "\u001b[2m\u001b[36m(_run pid=2109508)\u001b[0m INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n",
      "\u001b[2m\u001b[36m(_run pid=2109508)\u001b[0m INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host\n",
      "\u001b[2m\u001b[36m(_run pid=2109508)\u001b[0m INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n",
      "\u001b[2m\u001b[36m(_run pid=2109508)\u001b[0m WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(_run pid=2109408)\u001b[0m [2023-06-15 17:08:24.221] [info] [thread_pool.cc:30] Create a fixed thread pool with size 127\n"
     ]
    }
   ],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# In case you have a running secretflow runtime already.\n",
    "sf.shutdown()\n",
    "\n",
    "sf.init(['alice', 'bob', 'carol'], address='local')\n",
    "\n",
    "alice, bob = sf.PYU('alice'), sf.PYU('bob')\n",
    "conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])\n",
    "conf['runtime_config']['fxp_exp_mode']=1\n",
    "conf['runtime_config']['experimental_disable_mmul_split']=True\n",
    "spu = sf.SPU(conf)\n",
    "\n",
    "def get_model_params():\n",
    "    pretrained_model = FlaxGPT2LMHeadModel.from_pretrained(\"gpt2\")\n",
    "    return pretrained_model.params\n",
    "\n",
    "def get_token_ids():\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "    return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')\n",
    "\n",
    "\n",
    "# Alice provides the model parameters; Bob provides the input token IDs.\n",
    "model_params = alice(get_model_params)()\n",
    "input_token_ids = bob(get_token_ids)()\n",
    "\n",
    "# Move both inputs to the SPU device.\n",
    "device = spu\n",
    "model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)\n",
    "\n",
    "# Run the same text_generation function on SPU.\n",
    "output_token_ids = spu(text_generation)(input_token_ids_, model_params_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check the SPU output\n",
    "\n",
    "As you can see, running GPT-2 inference on SPU takes only a few extra lines of code. Now let's reveal the generated text from the SPU program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(_spu_compile pid=2109408)\u001b[0m 2023-06-15 17:09:12.722333: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst\n",
      "\u001b[2m\u001b[36m(_spu_compile pid=2109408)\u001b[0m 2023-06-15 17:09:12.722414: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst\n",
      "\u001b[2m\u001b[36m(_spu_compile pid=2109408)\u001b[0m 2023-06-15 17:09:12.722421: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(SPURuntime(device_id=None, party=bob) pid=2121303)\u001b[0m 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127\n",
      "\u001b[2m\u001b[36m(SPURuntime(device_id=None, party=alice) pid=2121301)\u001b[0m 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127\n",
      "\u001b[2m\u001b[36m(SPURuntime(device_id=None, party=carol) pid=2121304)\u001b[0m 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127\n",
      "-----------------------------------------------------------------\n",
      "Run on SPU:\n",
      "-----------------------------------------------------------------\n",
      "I enjoy walking with my cute dog, but I'm not sure if I'll ever\n",
      "-----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "outputs_ids = sf.reveal(output_token_ids)\n",
    "print('-'*65 + '\\nRun on SPU:\\n' + '-'*65)\n",
    "print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))\n",
    "print('-'*65)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the text generated on SPU is exactly the same as the text generated on CPU!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the end of the lab."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  },
  "vscode": {
   "interpreter": {
    "hash": "db45a4cb4cd37a8de684dfb7fcf899b68fccb8bd32d97c5ad13e5de1245c0986"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
